"""
Phase 5: Interaction Tests
============================
1. old_dep × exp_rev_gap → downgrade
2. old_dep × debt/GDP → downgrade
3. r_minus_g × old_dep → downgrade
4. global_oadr × exp_rev_gap → downgrade
5. oadr_spline_20 × debt/GDP → downgrade

Both logit (downgrade_any) and PanelGLS (rating_numeric) frameworks.

Output: table8_interactions.md, phase5_logit_results.csv, phase5_gls_results.csv
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.optimize import minimize
from scipy import stats as sp_stats
import warnings
warnings.filterwarnings('ignore')

# Hard-coded absolute location of this sub-project.
# NOTE(review): looks like a WSL mount of a Windows drive — confirm before
# running on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_asset_cliff")
# Parent of the sub-project; used below to locate sibling source code.
ROOT_DIR = PROJECT_DIR.parent
# Directory the input panel (cliff_panel.csv) is read from.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
# Directory the markdown table and the result CSVs are written to.
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# Put <ROOT_DIR>/multilateral/src at the front of the import path; this must
# run before the import below, which is presumably resolved from that folder.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS


def stars(p):
    """Return the significance marker for p-value *p*.

    ``'***'`` when p is below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1,
    and an empty string otherwise (including when no comparison holds,
    e.g. for NaN).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_logit_simple(df, y_var, x_vars, label, min_events=5):
    """Fit a binary logit by maximum likelihood and return the key results.

    Rows with a missing value in ``y_var``, any of ``x_vars`` or ``iso3`` are
    dropped first. Regressors are standardized for the optimization only;
    reported coefficients and standard errors are on the original scale.

    Parameters
    ----------
    df : DataFrame containing ``y_var``, every name in ``x_vars`` and ``iso3``.
    y_var : name of the outcome column (used as a 0/1 outcome; not validated).
    x_vars : list of regressor column names. An intercept is added internally.
    label : model name used in printed output and stored under ``'model'``.
    min_events : minimum sum of the outcome required to attempt estimation.

    Returns
    -------
    dict or None
        ``None`` (after printing the reason) when fewer than 30 complete rows
        remain, when the outcome sums to less than ``min_events``, or when
        estimation raises. Otherwise a dict with ``model``, ``estimator``,
        ``n_obs``, ``n_countries``, ``n_events``, ``pseudo_r2`` and, for each
        regressor, ``<name>_beta``, ``<name>_se`` and ``<name>_p``.
    """
    cols = [y_var] + x_vars + ['iso3']
    sub = df[cols].dropna()
    if len(sub) < 30:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(float)
    X = np.column_stack([np.ones(len(sub)), sub[x_vars].values.astype(float)])
    n, k = X.shape

    if y.sum() < min_events:
        print(f"  {label}: too few events ({int(y.sum())} < {min_events}), skipping")
        return None

    # Standardize the regressors (not the intercept) so BFGS works on
    # comparably scaled parameters; constant columns keep a unit divisor.
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def prob(design, beta):
        # Logistic CDF with the linear index clipped to avoid overflow in exp.
        z = np.clip(design @ beta, -30, 30)
        return 1 / (1 + np.exp(-z))

    def neg_ll(beta):
        p = np.clip(prob(X_std, beta), 1e-12, 1 - 1e-12)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def grad(beta):
        return -X_std.T @ (y - prob(X_std, beta))

    try:
        result = minimize(neg_ll, np.zeros(k), jac=grad,
                          method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})

        # Map the standardized-scale estimates back to the original scale.
        beta_std = result.x
        beta = np.zeros(k)
        beta[1:] = beta_std[1:] / x_stds
        beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)

        p_hat = np.clip(prob(X, beta), 1e-12, 1 - 1e-12)

        # Standard errors from the inverse of X'WX evaluated at the estimate.
        W = p_hat * (1 - p_hat)
        H = X.T @ (X * W[:, None])
        try:
            V = np.linalg.inv(H)
            se = np.sqrt(np.diag(V))
        except np.linalg.LinAlgError:
            se = np.full(k, np.nan)

        # Two-sided p-values from the standard normal distribution.
        pvalues = 2 * (1 - sp_stats.norm.cdf(np.abs(beta / se)))

        # Pseudo-R² = 1 - LL(model) / LL(intercept-only model).
        ll_model = -result.fun
        p_bar = y.mean()
        ll_null = n * (p_bar * np.log(p_bar + 1e-12) + (1 - p_bar) * np.log(1 - p_bar + 1e-12))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

    except Exception as exc:
        # Best-effort estimation: keep going with the other models, but say why
        # this one was dropped instead of failing silently.
        print(f"  {label}: logit estimation failed ({type(exc).__name__}: {exc}), skipping")
        return None

    res = {'model': label, 'estimator': 'logit', 'n_obs': n,
           'n_countries': sub['iso3'].nunique(), 'n_events': int(y.sum()),
           'pseudo_r2': pseudo_r2}

    print(f"\n  {label} (N={n}, events={int(y.sum())}, Pseudo-R²={pseudo_r2:.4f})")
    if not result.success:
        # BFGS can stop without meeting its tolerance; flag it so the
        # estimates below are read with care.
        print(f"    WARNING: optimizer did not report convergence: {result.message}")
    for i, name in enumerate(x_vars):
        sig = stars(pvalues[i + 1])
        print(f"    {name:35s} β={beta[i+1]:8.4f} (se={se[i+1]:.4f}) {sig}")
        res[f'{name}_beta'] = beta[i + 1]
        res[f'{name}_se'] = se[i + 1]
        res[f'{name}_p'] = pvalues[i + 1]

    return res


def fit_gls(df, dep_var, x_vars, label):
    """Estimate a PanelGLS model on the complete cases of *df*.

    Rows missing ``dep_var``, any of ``x_vars``, ``iso3`` or ``year`` are
    dropped. With fewer than 100 remaining rows the model is skipped and
    ``(None, None)`` is returned; otherwise the fitted model and its
    coefficient table (tagged with the label, N, number of countries and R²)
    are returned.
    """
    needed = [dep_var, *x_vars, 'iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 100:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None, None

    estimator = PanelGLS()
    estimator.fit(
        sample[dep_var].values,
        sample[x_vars].values,
        sample['iso3'].values,
        sample['year'].values,
    )
    print(f"\n  {label} (N={estimator.n_obs}, R²={estimator.r_squared:.4f})")
    estimator.summary(feature_names=x_vars)

    # Coefficient table plus model-level metadata repeated on every row.
    coef_table = estimator.to_dataframe(feature_names=x_vars)
    coef_table['model'] = label
    for attr in ('n_obs', 'n_countries', 'r_squared'):
        coef_table[attr] = getattr(estimator, attr)
    return estimator, coef_table


def write_markdown_table(path, title, headers, rows, notes=None):
    """Write a markdown table with a level-3 heading to *path*.

    Parameters
    ----------
    path : destination file (``str`` or ``Path``); missing parent
        directories are created.
    title : text of the ``###`` heading above the table.
    headers : column headers; the first column is left-aligned, the rest
        right-aligned.
    rows : iterable of rows; every cell is converted with ``str``.
    notes : optional footnote, written in italics below the table.
    """
    path = Path(path)

    alignment = [":--" if i == 0 else "--:" for i in range(len(headers))]
    lines = [
        f"### {title}",
        "",
        "| " + " | ".join(str(h) for h in headers) + " |",
        "|" + "|".join(alignment) + "|",
    ]
    lines.extend("| " + " | ".join(str(c) for c in row) + " |" for row in rows)
    if notes:
        lines += ["", f"*{notes}*"]
    lines.append("")  # trailing newline once joined

    # The output folder may not exist on a fresh checkout.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Saved: {path}")


def _logit_table_row(res, test_name, ix_name):
    """Build the Table 8 row for a logit fit, or None if the interaction term is absent."""
    beta_key = f'{ix_name}_beta'
    if beta_key not in res:
        # Never report a made-up zero for a term that was not in the model.
        return None
    ix_b = res[beta_key]
    ix_se = res[f'{ix_name}_se']
    ix_p = res[f'{ix_name}_p']
    return [test_name, "Logit→downgrade",
            f"{ix_b:.4f}{stars(ix_p)}", f"({ix_se:.4f})",
            str(res['n_obs']), f"{res['pseudo_r2']:.3f}"]


def _gls_table_row(model, rdf, test_name, ix_name):
    """Build the Table 8 row for a PanelGLS fit, or None if the interaction term is absent."""
    match = rdf[rdf['variable'] == ix_name]
    if len(match) == 0:
        return None
    first = match.iloc[0]
    c, se, p = first['coefficient'], first['std_error'], first['p_value']
    return [test_name, "PanelGLS→rating",
            f"{c:.4f}{stars(p)}", f"({se:.4f})",
            f"{model.n_obs:,}", f"{model.r_squared:.3f}"]


def main():
    """Run the five interaction tests and write Table 8 plus the result CSVs.

    Reads ``cliff_panel.csv`` from ``PROCESSED_DIR``, builds the interaction
    columns whose two inputs are both present, and for each test fits a logit
    on ``downgrade_any`` and a PanelGLS on ``rating_numeric`` with the same
    regressors. A test whose interaction column could not be built is skipped
    with a message. Outputs are written to ``TABLES_DIR``.
    """
    print("=" * 70)
    print("PHASE 5: Interaction Tests")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "cliff_panel.csv")
    print(f"Loaded: {df['iso3'].nunique()} countries, {len(df):,} obs")

    all_logit = []
    all_gls = []
    table_rows = []

    controls = ['rgdp_growth']

    # Create interaction terms (only when both inputs exist in the panel).
    interactions = [
        ('old_dep_x_exp_rev_gap', 'old_dep', 'exp_rev_gap'),
        ('old_dep_x_debt', 'old_dep', 'govt_debt_gdp'),
        ('r_minus_g_x_old_dep', 'r_minus_g', 'old_dep'),
        ('oadr_spline_20_x_debt', 'oadr_spline_20', 'govt_debt_gdp'),
        ('global_oadr_x_exp_rev_gap', 'global_oadr', 'exp_rev_gap'),
    ]
    for iname, var1, var2 in interactions:
        if var1 in df.columns and var2 in df.columns:
            df[iname] = df[var1] * df[var2]

    # One entry per test:
    # (section header, table row name, model label, interaction column,
    #  regressors in model order — controls are appended afterwards).
    tests = [
        ("TEST 1: OADR × Expenditure-Revenue Gap", "OADR × Exp-Rev Gap",
         "1. OADR×Exp-Rev Gap", 'old_dep_x_exp_rev_gap',
         ['old_dep', 'exp_rev_gap', 'govt_debt_gdp', 'old_dep_x_exp_rev_gap']),
        ("TEST 2: OADR × Debt/GDP", "OADR × Debt/GDP",
         "2. OADR×Debt", 'old_dep_x_debt',
         ['old_dep', 'govt_debt_gdp', 'old_dep_x_debt']),
        ("TEST 3: r-g × OADR", "r-g × OADR",
         "3. r-g×OADR", 'r_minus_g_x_old_dep',
         ['old_dep', 'r_minus_g', 'r_minus_g_x_old_dep', 'govt_debt_gdp']),
        ("TEST 4: Global OADR × Exp-Rev Gap", "Global OADR × Exp-Rev Gap",
         "4. Global OADR×Exp-Rev Gap", 'global_oadr_x_exp_rev_gap',
         ['old_dep', 'exp_rev_gap', 'global_oadr', 'global_oadr_x_exp_rev_gap']),
        ("TEST 5: OADR Spline(20%) × Debt/GDP", "OADR Spline(20%) × Debt",
         "5. Spline(20%)×Debt", 'oadr_spline_20_x_debt',
         ['old_dep', 'oadr_spline_20', 'govt_debt_gdp', 'oadr_spline_20_x_debt']),
    ]

    for header, test_name, label, ix_name, regressors in tests:
        print("\n" + "-" * 70)
        print(header)
        print("-" * 70)

        # Without the interaction column the model would silently degrade to
        # main effects only while still being labelled an interaction test.
        if ix_name not in df.columns:
            print(f"  {ix_name}: interaction inputs missing, skipping")
            continue

        x_vars = [v for v in regressors + controls if v in df.columns]

        # Logit on downgrade
        res = run_logit_simple(df, 'downgrade_any', x_vars, f"{label} → downgrade")
        if res:
            all_logit.append(res)
            row = _logit_table_row(res, test_name, ix_name)
            if row is not None:
                table_rows.append(row)

        # PanelGLS on rating
        model, rdf = fit_gls(df, 'rating_numeric', x_vars, f"{label} → rating")
        if rdf is not None:
            all_gls.append(rdf)
            row = _gls_table_row(model, rdf, test_name, ix_name)
            if row is not None:
                table_rows.append(row)

    # The output folder may not exist yet; the CSV writers below do not create it.
    TABLES_DIR.mkdir(parents=True, exist_ok=True)

    # ── Write table ──
    if table_rows:
        write_markdown_table(
            TABLES_DIR / "table8_interactions.md",
            "Table 8: Interaction Tests — Demographics × Fiscal Variables",
            ["Interaction", "Framework", "Interaction Coef", "(SE)", "N", "R²/Pseudo-R²"],
            table_rows,
            notes="Tests whether demographic aging amplifies fiscal pressure on ratings. "
                  "Logit: downgrade_any. PanelGLS: rating_numeric."
        )

    # ── Save results ──
    if all_logit:
        pd.DataFrame(all_logit).to_csv(TABLES_DIR / "phase5_logit_results.csv", index=False)
    if all_gls:
        pd.concat(all_gls, ignore_index=True).to_csv(
            TABLES_DIR / "phase5_gls_results.csv", index=False)
    print("\n  Saved: phase5 results")

    print("\n" + "=" * 70)
    print("Phase 5 complete.")
    print("=" * 70)


if __name__ == "__main__":
    main()
